
Add torch.compile for Whisper #30949

Conversation

zhenglongjiepheonix (Contributor)

This PR adds torch.compile support for the Whisper model, which uses an encoder-decoder architecture.

Comment on lines -297 to -306
if (
    is_cross_attention
    and past_key_value is not None
    and past_key_value[0].shape[2] == key_value_states.shape[1]
):
    # reuse k,v, cross_attentions
    key_states = past_key_value[0]
    value_states = past_key_value[1]
elif is_cross_attention:
@zhenglongjiepheonix (Contributor Author) commented on May 22, 2024

Here, according to my understanding, key_states and value_states are computed once and for all from the encoder hidden states when we are doing cross-attention, so ideally we can cache them. However, they are a little different from the existing caches: we don't need to update the cache at every generation step, we only have to fill it once in the first step. So we could either create another Cache class, in which case we need to pass in two caches (one for self-attention and one for cross-attention); even if we create a cache class that wraps both, we still need to modify the current get_cache logic, because it will definitely need more parameters to initialize the cache, and I don't know if the use case is general enough to justify a new cache class. Alternatively, we can initialize and update the cross-attention kv cache inside every layer, like recurrent gemma, but that requires manually resetting the cache between generations, because the current generation logic doesn't seem to consider the case where the model owns an inherent cache. I personally prefer the latter solution, but I think both will need some changes in the generation utils file, and I'm not sure which way might make dynamo or cudagraphs unhappy. What do you think is best @ArthurZucker?
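For illustration, a minimal sketch of the second option: a cross-attention cache owned by each decoder layer, filled once and reset manually between generations. The class and method names here are hypothetical, not part of the existing cache API:

import torch


class LayerCrossAttentionCache:
    """Hypothetical per-layer cache: the cross-attention k/v are computed from the
    encoder hidden states once, on the first decoding step, and only read afterwards."""

    def __init__(self):
        self.key_states = None
        self.value_states = None

    def is_filled(self) -> bool:
        return self.key_states is not None

    def fill(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        # called exactly once, on the first generation step
        self.key_states = key_states
        self.value_states = value_states

    def reset(self) -> None:
        # has to be called manually between generations, since the generation loop
        # currently does not know about caches owned by the model itself
        self.key_states = None
        self.value_states = None

Inside the attention forward, the layer would call is_filled() to decide whether to project the encoder hidden states or reuse the stored tensors.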

Contributor

Your understanding is correct: we compute the k/v states for the cross-attention once in the first forward pass and then save them to the cache. We would indeed want to cache the k/v states to avoid re-computing them at every step.

What do you think about the design proposed in this PR? #28931 (comment) Similar to your solution, it uses two caches: one for the self-attention and one for the cross-attention. Note that the cache specifics are out of date, given the recent changes to the cache API, but the high-level design remains valid!

@zhenglongjiepheonix (Contributor Author) commented on May 22, 2024

Yeah, it would be great if two StaticCaches worked, but the issue is that in that case we cannot tell whether we are in the first generation step or not: the shape of the cache is now always fixed to hold the maximum generation length, and using get_seq_length (to check whether any tokens have been processed yet) in a branch condition during tracing will indeed cause graph breaks.
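Concretely, the only way to detect "first step" from a pre-allocated static cache is to inspect the tensor contents. A schematic stand-in for the pattern, assuming the cache buffers are zero-initialized (this is not the actual StaticCache implementation):

import torch


def cross_attention_cache_is_empty(key_cache: torch.Tensor) -> torch.Tensor:
    # key_cache is pre-allocated to a fixed shape and zero-filled, so "have we
    # written anything yet?" has to be derived from its contents. The result is a
    # tensor, and using it in a Python `if` inside a compiled region is a
    # data-dependent branch, which is where the graph breaks come from.
    return ~key_cache.any()

Branching on this value at every decoding step is exactly the get_seq_length-in-a-condition pattern described above.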

@zhenglongjiepheonix (Contributor Author) commented on May 23, 2024

My current solution is to use two StaticCaches and add another flag in the attention layer to mark whether we are doing the first generation step. This causes a recompile in the second step, just like in Llama and Mistral, but should work fine for the subsequent steps. It is still nasty, though, because it breaks the current stand-alone cache design. Maybe I should indeed create a new OneShotCache class that hides the flag inside? cc @ArthurZucker @gante @sanchit-gandhi
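A rough sketch of that flag-based variant (names such as is_first_generation_step are made up for illustration; head reshaping and other attention details are omitted):

import torch
from torch import nn


def cross_attention_kv(
    encoder_hidden_states: torch.Tensor,
    k_proj: nn.Linear,
    v_proj: nn.Linear,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    is_first_generation_step: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the cross-attention k/v on the first step and write them into the
    static buffers; reuse the buffers on every later step. Branching on a plain
    Python bool keeps the condition out of the traced tensors, at the cost of one
    recompile when the flag flips after the first step."""
    if is_first_generation_step:
        key_cache.copy_(k_proj(encoder_hidden_states))
        value_cache.copy_(v_proj(encoder_hidden_states))
    return key_cache, value_cache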

@zhenglongjiepheonix (Contributor Author)

The current solution uses a new OneShotCache, and a tuple of two caches (self-attention cache, cross-attention cache) is expected for conditional generation.
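For readers following along, a minimal sketch of what such a one-shot cache could look like, based only on the attributes visible in this diff (key_cache, value_cache, query_cache_filled_status); the update/reset methods are guesses, and the actual class in the PR may differ:

import torch


class OneShotStaticCacheSketch:
    """Sketch of a cache that is written exactly once per generation (the cross-attention
    k/v) and then only read. All shapes are fixed up front, which keeps it friendly to
    torch.compile and CUDA graphs."""

    def __init__(self, num_layers, batch_size, num_heads, encoder_seq_len, head_dim,
                 dtype=torch.float32, device="cpu"):
        shape = (batch_size, num_heads, encoder_seq_len, head_dim)
        self.key_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
        # the "already filled" flag lives inside the cache, so the attention layer
        # does not need an extra argument to know whether this is the first step
        self._filled = [False] * num_layers

    def query_cache_filled_status(self, layer_idx: int) -> bool:
        return self._filled[layer_idx]

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        # only ever called on the first decoding step
        self.key_cache[layer_idx].copy_(key_states)
        self.value_cache[layer_idx].copy_(value_states)
        self._filled[layer_idx] = True
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reset(self):
        # called between generations, as discussed earlier in the thread
        for layer_idx in range(len(self.key_cache)):
            self._filled[layer_idx] = False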

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +1693 to +1699
if self.config.is_encoder_decoder:
    # manually set another cache for cross attention
    encoder_outputs = model_kwargs["encoder_outputs"][0]
    model_kwargs["past_key_values"] = (
        model_kwargs["past_key_values"],
        self._get_cache("one_shot", encoder_outputs.shape[0], encoder_outputs.shape[1], '_cross_attn_cache')
    )
@zhenglongjiepheonix (Contributor Author)

I wonder if there is a better way to do this: in the encoder-decoder scenario we need a tuple of two caches here according to the current design, but this feels hardcoded and easy to break.
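One possible alternative (purely illustrative, not something this PR implements) would be a small named container instead of a bare tuple, so the (self-attention, cross-attention) ordering is not an implicit convention:

from dataclasses import dataclass

from transformers.cache_utils import Cache


@dataclass
class EncoderDecoderCachePair:
    """Hypothetical container pairing the two caches used by an encoder-decoder model."""
    self_attention_cache: Cache
    cross_attention_cache: Cache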

Comment on lines +306 to +310
if is_cross_attention and (
    (isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache))
    or (isinstance(past_key_value, OneShotStaticCache) and past_key_value.query_cache_filled_status(self.layer_idx))
):
    key_states = past_key_value.key_cache[self.layer_idx]
    value_states = past_key_value.value_cache[self.layer_idx]
    need_update_cache = False
@zhenglongjiepheonix (Contributor Author)

Actually, we can always use OneShotStaticCache for the kv cache here for simplicity, because using DynamicCache won't give us any memory benefit on the cross-attention kv cache.
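If only OneShotStaticCache had to be supported for the cross-attention k/v, the branch above could shrink to roughly the following (a sketch reusing the names from the diff, not a drop-in change):

if is_cross_attention and past_key_value.query_cache_filled_status(self.layer_idx):
    # cross-attention k/v were already computed on the first step; just reuse them
    key_states = past_key_value.key_cache[self.layer_idx]
    value_states = past_key_value.value_cache[self.layer_idx]
    need_update_cache = False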

@zhenglongjiepheonix zhenglongjiepheonix changed the title [WIP] Add torch.compile for Whisper Add torch.compile for Whisper May 27, 2024
@ArthurZucker ArthurZucker mentioned this pull request Jun 3, 2024
@ArthurZucker ArthurZucker removed their request for review June 12, 2024 14:41
github-actions bot commented Jul 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jul 15, 2024